{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, sys\n",
    "from captcha.image import ImageCaptcha\n",
    "import random\n",
    "import string\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn\n",
    "import torchvision\n",
    "from torch.autograd import Variable\n",
    "from torchvision import transforms\n",
    "from torch.utils.data import Dataset\n",
    "from torchvision.transforms import Compose, ToTensor\n",
    "from torch.utils.data import DataLoader,TensorDataset\n",
    "from PIL import Image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculat_acc(output, target):\n",
    "    output, target = output.view(-1,10,4), target.view(-1, 4)\n",
    "#     output = nn.functional.softmax(output, dim=1)\n",
    "    output = torch.argmax(output, dim=1)\n",
    "    output = output.float()\n",
    "    \n",
    "#     target = torch.argmax(target, dim=1)\n",
    "    output, target = output.view(-1, 4), target.view(-1, 4)\n",
    "    correct_list = []\n",
    "    for i, j in zip(target, output):\n",
    "        if torch.equal(i, j):\n",
    "            correct_list.append(1)\n",
    "        else:\n",
    "            correct_list.append(0)\n",
    "    acc = sum(correct_list) / len(correct_list)\n",
    "    return acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "#定义验证码大小，n_class,\n",
    "height,width,batch_size,n_class = 50,100,32,10\n",
    "#生成随机数\n",
    "def get_code():\n",
    "    raw = string.digits\n",
    "#     + string.ascii_lowercase #小写\n",
    "#     + string.ascii_uppercase#大写\n",
    "\n",
    "    random_code = ''.join(random.sample(raw,4))\n",
    "    return raw,random_code\n",
    "\n",
    "#生成验证码\n",
    "def generator(height,width,batch_size,n_class):\n",
    "    X = np.zeros((batch_size,height,width,3))\n",
    "    y = np.zeros((batch_size,4))\n",
    "    \n",
    "    #验证码生成器\n",
    "    generator = ImageCaptcha(height=height,width=width)\n",
    "    for i in range(batch_size):\n",
    "        raw ,random_code = get_code()\n",
    "        img = np.array((generator.generate_image(random_code)),dtype=np.float32)\n",
    "        X[i] = img/255.0\n",
    "\n",
    "        for step,ch in enumerate(random_code):\n",
    "            y[i][step] = int(ch)\n",
    "    return X,y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "train_batch_size = 1000\n",
    "test_batch_size = 128\n",
    "train_data,train_label = generator(height,width,train_batch_size,n_class)\n",
    "train_data = np.transpose(train_data,(0,3,1,2))\n",
    "train_data = torch.from_numpy(train_data)\n",
    "train_label = torch.Tensor(train_label)\n",
    "train_datas = TensorDataset(train_data,train_label)\n",
    "\n",
    "test_data,test_label = generator(height,width,test_batch_size,n_class)\n",
    "test_data = np.transpose(test_data,(0,3,1,2))\n",
    "test_data = torch.from_numpy(test_data)\n",
    "test_label = torch.Tensor(test_label)\n",
    "test_datas = TensorDataset(test_data,test_label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "batch_size = 32\n",
    "train_data_loader = DataLoader(train_datas, batch_size=batch_size,num_workers=0, shuffle=True, drop_last=True)\n",
    "\n",
    "test_data_loader = DataLoader(test_datas, batch_size=batch_size,num_workers=0, shuffle=True, drop_last=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "for i,j in train_data_loader:\n",
    "    print(j.shape)\n",
    "#     j = j.view(-1,4,4)\n",
    "#     print(j.view(-1,4,4).shape)\n",
    "#     print(np.argmax(j,axis=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CNN(nn.Module):\n",
    "    def __init__(self, num_class=10, num_char=4):\n",
    "        super(CNN, self).__init__()\n",
    "        self.num_class = num_class\n",
    "        self.num_char = num_char\n",
    "        self.conv = nn.Sequential(\n",
    "                nn.Conv2d(3, 16, 3,stride=(1,1), padding=(2,2)),\n",
    "                nn.MaxPool2d(2, 2),\n",
    "                nn.ReLU(),\n",
    "                nn.Conv2d(16, 32, 3,stride=(1,2), padding=(1,1)),\n",
    "                nn.MaxPool2d(2, 2),\n",
    "                nn.ReLU(),\n",
    "                nn.Conv2d(32, 64, 3, padding=(1,1)),\n",
    "                nn.MaxPool2d(2, 2),\n",
    "                nn.Conv2d(64, 10, 3,stride=(2,2)),\n",
    "                )\n",
    "        self.act = nn.LogSoftmax(dim=1)\n",
    " \n",
    "    def forward(self, x):\n",
    "        x = self.conv(x)\n",
    "        x = x.view(-1,10,4)#torch.Size([128, 10, 4])\n",
    "#         print(x.shape)\n",
    "        x = self.act(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_loss: 2.305|train_acc: 0.0\n",
      "test_loss: 2.303|test_acc: 0.0\n",
      "train_loss: 2.302|train_acc: 0.0\n",
      "test_loss: 2.301|test_acc: 0.0\n",
      "train_loss: 2.3|train_acc: 0.0\n",
      "test_loss: 2.295|test_acc: 0.0\n",
      "train_loss: 2.294|train_acc: 0.0\n",
      "test_loss: 2.279|test_acc: 0.0\n",
      "train_loss: 2.265|train_acc: 0.001008\n",
      "test_loss: 2.229|test_acc: 0.0\n",
      "train_loss: 2.205|train_acc: 0.001008\n",
      "test_loss: 2.172|test_acc: 0.0\n",
      "train_loss: 2.142|train_acc: 0.002016\n",
      "test_loss: 2.108|test_acc: 0.0\n",
      "train_loss: 2.073|train_acc: 0.00504\n",
      "test_loss: 2.037|test_acc: 0.007812\n",
      "train_loss: 1.993|train_acc: 0.01411\n",
      "test_loss: 1.963|test_acc: 0.007812\n",
      "train_loss: 1.908|train_acc: 0.01109\n",
      "test_loss: 1.87|test_acc: 0.01562\n",
      "train_loss: 1.816|train_acc: 0.01714\n",
      "test_loss: 1.794|test_acc: 0.02344\n",
      "train_loss: 1.741|train_acc: 0.02319\n",
      "test_loss: 1.739|test_acc: 0.03125\n",
      "train_loss: 1.66|train_acc: 0.03629\n",
      "test_loss: 1.633|test_acc: 0.0625\n",
      "train_loss: 1.586|train_acc: 0.03931\n",
      "test_loss: 1.574|test_acc: 0.05469\n",
      "train_loss: 1.516|train_acc: 0.05948\n",
      "test_loss: 1.502|test_acc: 0.1016\n",
      "train_loss: 1.452|train_acc: 0.07157\n",
      "test_loss: 1.451|test_acc: 0.1328\n",
      "train_loss: 1.384|train_acc: 0.08266\n",
      "test_loss: 1.388|test_acc: 0.1484\n",
      "train_loss: 1.329|train_acc: 0.1018\n",
      "test_loss: 1.326|test_acc: 0.1328\n",
      "train_loss: 1.267|train_acc: 0.124\n",
      "test_loss: 1.253|test_acc: 0.1641\n",
      "train_loss: 1.223|train_acc: 0.1321\n",
      "test_loss: 1.226|test_acc: 0.1875\n",
      "train_loss: 1.165|train_acc: 0.1643\n",
      "test_loss: 1.159|test_acc: 0.2031\n",
      "train_loss: 1.11|train_acc: 0.1966\n",
      "test_loss: 1.107|test_acc: 0.2188\n",
      "train_loss: 1.053|train_acc: 0.2036\n",
      "test_loss: 1.069|test_acc: 0.2266\n",
      "train_loss: 0.9989|train_acc: 0.2208\n",
      "test_loss: 1.02|test_acc: 0.2812\n",
      "train_loss: 0.9634|train_acc: 0.255\n",
      "test_loss: 0.9495|test_acc: 0.2891\n",
      "train_loss: 0.9022|train_acc: 0.2883\n",
      "test_loss: 0.9209|test_acc: 0.3203\n",
      "train_loss: 0.8626|train_acc: 0.3206\n",
      "test_loss: 0.888|test_acc: 0.3359\n",
      "train_loss: 0.825|train_acc: 0.3347\n",
      "test_loss: 0.8299|test_acc: 0.375\n",
      "train_loss: 0.7848|train_acc: 0.3468\n",
      "test_loss: 0.7994|test_acc: 0.3281\n",
      "train_loss: 0.7364|train_acc: 0.379\n",
      "test_loss: 0.7533|test_acc: 0.3906\n",
      "train_loss: 0.7011|train_acc: 0.4062\n",
      "test_loss: 0.7016|test_acc: 0.4297\n",
      "train_loss: 0.6547|train_acc: 0.4395\n",
      "test_loss: 0.6465|test_acc: 0.4375\n",
      "train_loss: 0.6128|train_acc: 0.4829\n",
      "test_loss: 0.6428|test_acc: 0.4297\n",
      "train_loss: 0.5798|train_acc: 0.5071\n",
      "test_loss: 0.5585|test_acc: 0.5234\n",
      "train_loss: 0.538|train_acc: 0.5504\n",
      "test_loss: 0.5346|test_acc: 0.5703\n",
      "train_loss: 0.4999|train_acc: 0.5746\n",
      "test_loss: 0.4996|test_acc: 0.6016\n",
      "train_loss: 0.4793|train_acc: 0.5887\n",
      "test_loss: 0.4986|test_acc: 0.5391\n",
      "train_loss: 0.4573|train_acc: 0.6069\n",
      "test_loss: 0.4437|test_acc: 0.6484\n",
      "train_loss: 0.4295|train_acc: 0.6341\n",
      "test_loss: 0.4121|test_acc: 0.6641\n",
      "train_loss: 0.393|train_acc: 0.6694\n",
      "test_loss: 0.4|test_acc: 0.6094\n",
      "train_loss: 0.3673|train_acc: 0.6825\n",
      "test_loss: 0.3409|test_acc: 0.7109\n",
      "train_loss: 0.3412|train_acc: 0.7248\n",
      "test_loss: 0.3227|test_acc: 0.7109\n",
      "train_loss: 0.3096|train_acc: 0.754\n",
      "test_loss: 0.2779|test_acc: 0.7891\n",
      "train_loss: 0.2783|train_acc: 0.7974\n",
      "test_loss: 0.2913|test_acc: 0.6953\n",
      "train_loss: 0.2692|train_acc: 0.7964\n",
      "test_loss: 0.2388|test_acc: 0.7891\n",
      "train_loss: 0.241|train_acc: 0.8357\n",
      "test_loss: 0.2215|test_acc: 0.8516\n",
      "train_loss: 0.2172|train_acc: 0.8679\n",
      "test_loss: 0.1911|test_acc: 0.8828\n",
      "train_loss: 0.1975|train_acc: 0.872\n",
      "test_loss: 0.1701|test_acc: 0.9141\n",
      "train_loss: 0.1793|train_acc: 0.8962\n",
      "test_loss: 0.1642|test_acc: 0.9141\n",
      "train_loss: 0.1603|train_acc: 0.9214\n",
      "test_loss: 0.1354|test_acc: 0.9453\n",
      "train_loss: 0.1366|train_acc: 0.9446\n",
      "test_loss: 0.1233|test_acc: 0.9531\n",
      "train_loss: 0.1316|train_acc: 0.9556\n",
      "test_loss: 0.1228|test_acc: 0.9375\n",
      "train_loss: 0.1196|train_acc: 0.9536\n",
      "test_loss: 0.1058|test_acc: 0.9922\n",
      "train_loss: 0.1078|train_acc: 0.9688\n",
      "test_loss: 0.09284|test_acc: 0.9766\n",
      "train_loss: 0.09377|train_acc: 0.9808\n",
      "test_loss: 0.0825|test_acc: 0.9922\n",
      "train_loss: 0.08176|train_acc: 0.9859\n",
      "test_loss: 0.07425|test_acc: 0.9922\n",
      "train_loss: 0.07315|train_acc: 0.9909\n",
      "test_loss: 0.06361|test_acc: 0.9844\n",
      "train_loss: 0.06906|train_acc: 0.9919\n",
      "test_loss: 0.06138|test_acc: 0.9922\n",
      "train_loss: 0.05947|train_acc: 0.998\n",
      "test_loss: 0.05061|test_acc: 1.0\n",
      "train_loss: 0.05186|train_acc: 0.999\n",
      "test_loss: 0.0443|test_acc: 1.0\n",
      "train_loss: 0.04518|train_acc: 1.0\n",
      "test_loss: 0.04136|test_acc: 1.0\n",
      "train_loss: 0.04069|train_acc: 0.999\n",
      "test_loss: 0.03789|test_acc: 1.0\n",
      "train_loss: 0.03789|train_acc: 1.0\n",
      "test_loss: 0.0341|test_acc: 1.0\n",
      "train_loss: 0.03374|train_acc: 1.0\n",
      "test_loss: 0.02882|test_acc: 1.0\n",
      "train_loss: 0.03075|train_acc: 1.0\n",
      "test_loss: 0.02755|test_acc: 1.0\n",
      "train_loss: 0.02878|train_acc: 1.0\n",
      "test_loss: 0.02604|test_acc: 1.0\n",
      "train_loss: 0.02636|train_acc: 1.0\n",
      "test_loss: 0.02447|test_acc: 1.0\n",
      "train_loss: 0.02398|train_acc: 1.0\n",
      "test_loss: 0.02391|test_acc: 1.0\n",
      "train_loss: 0.02231|train_acc: 1.0\n",
      "test_loss: 0.01991|test_acc: 1.0\n",
      "train_loss: 0.02061|train_acc: 1.0\n",
      "test_loss: 0.01994|test_acc: 1.0\n",
      "train_loss: 0.02004|train_acc: 1.0\n",
      "test_loss: 0.01734|test_acc: 1.0\n",
      "train_loss: 0.0188|train_acc: 1.0\n",
      "test_loss: 0.01659|test_acc: 1.0\n",
      "train_loss: 0.01717|train_acc: 1.0\n",
      "test_loss: 0.01579|test_acc: 1.0\n",
      "train_loss: 0.01557|train_acc: 1.0\n",
      "test_loss: 0.01378|test_acc: 1.0\n",
      "train_loss: 0.01453|train_acc: 1.0\n",
      "test_loss: 0.01313|test_acc: 1.0\n",
      "train_loss: 0.01401|train_acc: 1.0\n",
      "test_loss: 0.013|test_acc: 1.0\n",
      "train_loss: 0.01367|train_acc: 1.0\n",
      "test_loss: 0.01293|test_acc: 1.0\n",
      "train_loss: 0.01258|train_acc: 1.0\n",
      "test_loss: 0.01163|test_acc: 1.0\n",
      "train_loss: 0.01184|train_acc: 1.0\n",
      "test_loss: 0.01025|test_acc: 1.0\n",
      "train_loss: 0.01109|train_acc: 1.0\n",
      "test_loss: 0.01007|test_acc: 1.0\n"
     ]
    }
   ],
   "source": [
    "cnn = CNN()\n",
    "if torch.cuda.is_available():\n",
    "    cnn.cuda()\n",
    "optimizer = torch.optim.Adam(cnn.parameters(), lr=0.001)\n",
    "criterion = nn.NLLLoss(reduction='mean')\n",
    "\n",
    "for epoch in range(80):\n",
    "    cnn.train()\n",
    "    loss_history = []\n",
    "    acc_history = []\n",
    "    for img, target in train_data_loader:\n",
    "        output = cnn(img.float())\n",
    "        loss = criterion(output, target.long())\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "        acc = calculat_acc(output, target)\n",
    "        acc_history.append(float(acc))\n",
    "        loss_history.append(float(loss))\n",
    "    print('train_loss: {:.4}|train_acc: {:.4}'.format(\n",
    "            torch.mean(torch.Tensor(loss_history)),\n",
    "            torch.mean(torch.Tensor(acc_history))))\n",
    "\n",
    "    loss_history = []\n",
    "    acc_history = []\n",
    "#     cnn.eval()\n",
    "    for img, target in test_data_loader:\n",
    "        output = cnn(img.float())#torch.Size([128, 10, 4])\n",
    "        loss = criterion(output, target.long())\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        acc = calculat_acc(output, target)\n",
    "        acc_history.append(float(acc))\n",
    "        loss_history.append(float(loss))\n",
    "        \n",
    "    print('test_loss: {:.4}|test_acc: {:.4}'.format(\n",
    "            torch.mean(torch.Tensor(loss_history)),\n",
    "            torch.mean(torch.Tensor(acc_history))\n",
    "            ))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(cnn.state_dict(), 'params.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[4, 1, 5, 8]])\n",
      "tensor([4., 1., 5., 8.])\n"
     ]
    }
   ],
   "source": [
    "for img,target in test_data_loader:\n",
    "#     print(target[1].shape)\n",
    "    cnn.eval()\n",
    "    img = img[1]\n",
    "    target = target[1]\n",
    "    img = img.unsqueeze_(dim=0)\n",
    "\n",
    "output = cnn(img.float())\n",
    "out = torch.argmax(output, dim=1)\n",
    "print(out)\n",
    "print(target)\n",
    "#     loss = criterion(output, target.long())\n",
    "#     optimizer.zero_grad()\n",
    "#     loss.backward()\n",
    "#     optimizer.step()\n",
    "\n",
    "#     acc = calculat_acc(output, target)\n",
    "#     acc_history.append(float(acc))\n",
    "#     loss_history.append(float(loss))\n",
    "# print('train_loss: {:.4}|train_acc: {:.4}'.format(\n",
    "#         torch.mean(torch.Tensor(loss_history)),\n",
    "#         torch.mean(torch.Tensor(acc_history))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
